/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com>
 *
 */
#include "FlowManager.h"
#include "FlowForwarder.h"
#include <iomanip> // setw
#include <netinet/in.h>
#include <arpa/inet.h>
#include <boost/format.hpp>

namespace aiengine {

// Builds an empty flow manager.
//
// The counters start at zero, the timeout takes the value of flowTimeout,
// expired flows are released by default (release_flows_ is true) and no
// flow or TCP-info cache is attached yet.
//
// name: label kept in name_; it is printed as the heading of the statistics.
FlowManager::FlowManager(const std::string &name)
	: name_(name)
	, total_process_flows_(0)
	, total_timeout_flows_(0)
	, timeout_(flowTimeout)
	, release_flows_(true)
	, flowTable_()
	, flow_cache_(nullptr)
	, tcp_info_cache_(nullptr)
	, protocol_()
	, lookup_flow_()
{
}

// Hands every stored flow and node back to its cache before the manager
// is destroyed.
FlowManager::~FlowManager() {
	flush();

	// flush() already empties the table; this clear() is kept as a final safeguard.
	flowTable_.table_.clear();
}

// Estimates the memory used by the stored flows: the number of entries in
// the flow table multiplied by the cost of one entry
// (FlowCache::flowSize plus sizeof(FlowNode)).
int64_t FlowManager::getAllocatedMemory() const {

	const auto bytes_per_entry = FlowCache::flowSize + sizeof(FlowNode);
	const auto entries = static_cast<int64_t>(flowTable_.size());

	return entries * bytes_per_entry;
}

void FlowManager::add(const SharedPointer<Flow> &flow) {

#ifdef DEBUG
	std::cout << __FILE__ << "(" << this << "):" << __func__ << ":flow:" << flow << " total flows:" << flowTable_.size() << std::endl;
#endif

	if (SharedPointer<FlowNode> node = flowTable_.cache_->acquire(); node) {
		node->flow = flow;
		flowTable_.attach(node.get());
		flowTable_.table_[flow->getId()] = node;
		++total_process_flows_;
	}
}

void FlowManager::remove(const SharedPointer<Flow> &flow) {

#ifdef DEBUG
	std::cout << __FILE__ << "(" << this << "):" << __func__ << ":flow:" << flow << " total flows:" << flowTable_.size() << std::endl;
#endif

	if (auto it = flowTable_.table_.find(flow->getId()); it != flowTable_.table_.end()) {
		SharedPointer<FlowNode> node = (*it).second;

		flowTable_.detach(node.get());
		flowTable_.table_.erase(it);
		flowTable_.cache_->release(node);
	}
}

// Looks a flow up by either of its two hashes.
//
// hash1 is tried first and hash2 only when hash1 is not in the table.
// On a hit the node is detached and attached again, which re-inserts it in
// the node list, and the flow is stored in lookup_flow_.
//
// Returns a reference to the member lookup_flow_; it is empty when neither
// hash is known and it is overwritten by the next call.
SharedPointer<Flow>& FlowManager::find(unsigned long hash1, unsigned long hash2) {

#ifdef DEBUG
	std::cout << __FILE__ << ":" << __func__ << " total flows:" << flowTable_.table_.size() << std::endl;
#endif
	lookup_flow_.reset();

	auto &table = flowTable_.table_;
	auto pos = table.find(hash1);

	if (pos == table.end())
		pos = table.find(hash2);

	if (pos != table.end()) {
		SharedPointer<FlowNode> hit = pos->second;

		// Refresh the position of the node in the list
		flowTable_.detach(hit.get());
		flowTable_.attach(hit.get());
		lookup_flow_ = hit->flow;
	}
	return lookup_flow_;
}

#if defined(STAND_ALONE_TEST) || defined(TESTING)
void FlowManager::showFlowsByTime() {

	FlowNode *node = flowTable_.head;
	while (node) {
		
      		std::cout << __FILE__ << ":" << __func__ << ":Checking: " << *node->flow.get() <<  " lastPacketTime:" << node->flow->getLastPacketTime() << std::endl;
		node = node->next;
	}
}

#endif

void FlowManager::release(const SharedPointer<Flow> &flow) {

	// Release to their corresponding caches the attached objects
	if (auto ff = flow->forwarder.lock(); ff)
		if (ProtocolPtr l7proto = ff->getProtocol(); l7proto)
			l7proto->releaseFlowInfo(flow.get());

	// Release the TCP info if attached
	if ((tcp_info_cache_)and(flow->getProtocol() == IPPROTO_TCP))
		if (SharedPointer<TCPInfo> tcp_info = flow->getTCPInfo(); tcp_info)
			tcp_info_cache_->release(tcp_info);

	if (flow_cache_)
		flow_cache_->release(flow);
}

// Expires the flows whose last packet is older than the timeout.
//
// The node list is walked from the tail towards the head and the walk stops
// at the first flow that has not expired yet. A flow is expired when
// getLastPacketTime() + timeout_ <= current_time. Every expired flow bumps
// total_timeout_flows_, is reported to the database adaptor of the protocol
// (binding builds only) and, when release_flows_ is set, is removed from the
// table and released to the caches.
//
// current_time: the time the flows are compared against.
void FlowManager::updateTimers(const std::time_t current_time) {

#if defined(RUBY_BINDING)
	// Ruby builds collect the expired flows here and process them after the loop
	std::list<SharedPointer<Flow>> flow_list;
#endif
	int expire_flows = 0;

#ifdef DEBUG
	char mbstr[64];
	struct tm atime;
        std::strftime(mbstr, 64, "%D %X", std::localtime(&current_time));

        std::cout << __FILE__ << ":" << __func__ << "(" << name_ << "):Checking Timers at " << mbstr << " total flows:" << flowTable_.table_.size() << std::endl;

#endif
	// The list is walked backwards because the oldest flows are at the end
        FlowNode *node = flowTable_.tail;
        while (node) {
		SharedPointer<Flow> flow = node->flow;
		// Saved before the node can be removed from the list below
		FlowNode *node_temp = node->prev;
#ifdef DEBUG
      		std::cout << __FILE__ << ":" << __func__ << ":Checking: " << *flow.get() <<  " lastPacketTime:" << flow->getLastPacketTime();
		std::cout << " timeout:" << timeout_ << " currentTime:" << current_time;
		//std::cout << " [ " << flow->getLastPacketTime() << " + " << timeout_ << " <= " << current_time << " ]" << std::endl;
		std::cout << " [ " << current_time << " - " << flow->getLastPacketTime();
		std::cout << " (" << current_time - flow->getLastPacketTime()  << ") > " << timeout_ << " ]" << std::endl;
#endif
		if (flow->getLastPacketTime() + timeout_ <= current_time ) {
			Flow *tmpflow = flow.get();
			++expire_flows;
			++total_timeout_flows_;
#ifdef DEBUG
      		std::cout << __FILE__ << ":EXPIRED:" << __func__ << "(" << name_ << "):Flow Expires: " << *flow.get() <<  " last pkt seen:";
		std::cout << current_time - flow->getLastPacketTime() << std::endl;
#endif


#if (defined(PYTHON_BINDING) || defined(JAVA_BINDING) || defined(LUA_BINDING) || defined(GO_BINDING))

			// Tell the database adaptor of the protocol, if one is set, that the flow goes away
                        if (!protocol_.expired()) {
                                ProtocolPtr proto = protocol_.lock();

                                if (proto->getDatabaseObjectIsSet())
                                        proto->databaseAdaptorRemoveHandler(tmpflow);
                        }
#endif

			if (release_flows_) { // Release the flows
				// Remove the flow from the table
#ifdef DEBUG
        			std::cout << __FILE__ << ":" << __func__ << ":Flow Expires: " << *flow.get() <<  " releasing from the multi_index" <<std::endl;
#endif

#if defined(RUBY_BINDING)
				flow_list.push_front(flow);
#else
				remove(flow);
				release(flow);
#endif
			}
		} else {
			// The first flow that is still alive ends the walk
			break;
		}
		node = node_temp;
	}

#if defined(RUBY_BINDING)

	// The flows that are going to be removed are kept on a list in order to
	// prevent problems with the ruby threads generated by the rb_funcall method.
	// Creating and managing the std::list has an extra cost that the other
	// builds do not pay.
	// NOTE(review): release_flow() is not defined in the visible part of this
	// file and remove() is not called on this path -- confirm that
	// release_flow() also takes the flow out of the table.

	for (auto f: flow_list) {
        	if (!protocol_.expired()) {
                	ProtocolPtr proto = protocol_.lock();

                        if (proto->getDatabaseObjectIsSet())
                        	proto->databaseAdaptorRemoveHandler(f.get());
		}
		release_flow(f);
	}

#endif

#ifdef DEBUG
        std::cout << __FILE__ << ":" << __func__ << "(" << name_ << "):Total expire flows " << expire_flows << " on table:" << flowTable_.table_.size() << std::endl;
#endif
	return;
}

// Purges the expired flows using the current system clock as reference time.
void FlowManager::purge() {

	purge(std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()));
}

void FlowManager::purge(std::time_t current_time) {

        FlowNode *node = flowTable_.tail;
        while (node) {
		SharedPointer<Flow> flow = node->flow;
		FlowNode *node_temp = node->prev;

		if (flow->getLastPacketTime() + timeout_ <= current_time ) {
			remove(flow);
                        release(flow);
			++total_timeout_flows_;
                }
                node = node_temp;
	}
}

void FlowManager::flush() {

	for (auto &item: flowTable_.table_) {
		SharedPointer<FlowNode> node = item.second;

		release(node->flow);
		flowTable_.cache_->release(node);
	}

	flowTable_.table_.clear();
	flowTable_.head = flowTable_.tail = nullptr;

	data_time_ = boost::posix_time::microsec_clock::local_time();
}

// Removes and releases the flows that belong to one protocol.
//
// The protocol name of a flow is the name of the protocol of its forwarder;
// a flow without forwarder, or whose forwarder has no protocol, is named
// "None". The comparison with protoname ignores case.
//
// protoname: name of the protocol whose flows must be flushed.
void FlowManager::flush(const std::string &protoname) {

	for (auto it = flowTable_.table_.begin(); it != flowTable_.table_.end(); ) {
		SharedPointer<FlowNode> node = (*it).second;
		SharedPointer<Flow> flow = node->flow;
		const char *name = "None";

		// Some flows could be not attached to a Protocol, for example syn packets,
		// syn/ack packets and so on. The forwarder may also have no protocol
		// (release() guards that same case), so both pointers are checked.
		if (auto ff = flow->forwarder.lock(); ff)
			if (ProtocolPtr proto = ff->getProtocol(); proto)
				name = proto->name();

		if (!boost::iequals(protoname, name)) {
			++it;
			continue;
		}

		flowTable_.detach(node.get());
		// erase() returns the iterator that follows the erased element
		it = flowTable_.table_.erase(it);

		release(flow);
		flowTable_.cache_->release(node);
	}
}

// Writes the statistics of a flow manager: timeout, data time, allocated
// memory (converted to a readable unit by unitConverter), processed flows,
// stored flows and timed-out flows. The stream is flushed at the end.
std::ostream& operator<< (std::ostream &out, const FlowManager &fm) {

	std::string unit = "Bytes";
	uint64_t memory = fm.getAllocatedMemory();

	unitConverter(memory, unit);

	// The amount and its unit share one column, so the number gets the remaining width
	const auto memory_width = 9 - unit.length();

	out << fm.name_ << " statistics\n";
	out << "\t" << "Timeout:                " << std::setw(10) << fm.timeout_ << "\n";
	out << "\t" << "Data time:             " << ElapsedTime().compute(fm.data_time_) << "\n";
	out << "\t" << "Total allocated:        " << std::setw(memory_width) << memory << " " << unit << "\n";
	out << "\t" << "Total process flows:    " << std::setw(10) << fm.total_process_flows_ << "\n";
	out << "\t" << "Total flows:            " << std::setw(10) << fm.flowTable_.size() << "\n";
	out << "\t" << "Total timeout flows:    " << std::setw(10) << fm.total_timeout_flows_ << std::endl;

	return out;
}

// Prints one flow as a table row: five tuple, total bytes, total packets,
// layer-7 protocol name and state, followed by what flow.show() adds.
//
// out:          stream that receives the row (no end of line is written).
// flow:         flow to print.
// current_time: reference time handed to flow.status() with the timeout.
void FlowManager::print_pretty_flow(std::basic_ostream<char> &out, const Flow &flow, std::time_t current_time) const {

	constexpr int forward = static_cast<int>(FlowDirection::FORWARD);
	constexpr int backward = static_cast<int>(FlowDirection::BACKWARD);

	const char* flow_state = flow.status(current_time, timeout_);
	const char* proto_name = flow.getL7ProtocolName();

	// The two directions are added in 64 bits: a 32-bit sum overflows (or
	// truncates the counters) on flows that moved a couple of gigabytes.
	const int64_t total_bytes = static_cast<int64_t>(flow.total_bytes[forward]) + flow.total_bytes[backward];
	const int64_t total_packets = static_cast<int64_t>(flow.total_packets[forward]) + flow.total_packets[backward];

	std::ostringstream fivetuple;

	fivetuple << "[" << flow.getSrcAddrDotNotation() << ":" << flow.getSourcePort() << "]:" << flow.getProtocol()
		<< ":[" << flow.getDstAddrDotNotation() << ":" << flow.getDestinationPort() <<"]";

	out << boost::format("%-64s %-10d %-10d %-12s %-8s") % fivetuple.str() % total_bytes % total_packets % proto_name % flow_state << " ";

	flow.show(out);
}

// Fills a JSON object with one flow: the layer-7 protocol name goes under
// the key "name" and flow.show() adds the rest.
void FlowManager::print_pretty_flow(Json &out, const Flow &flow) const {

	const auto l7name = flow.getL7ProtocolName();

	out["name"] = l7name;
	flow.show(out);
}

void FlowManager::showFlows(int limit) const {

	showFlows(OutputManager::getInstance()->out(), limit);
}

void FlowManager::showFlows(const std::string &protoname, int limit) const {

	showFlows(OutputManager::getInstance()->out(), limit, protoname);
}

void FlowManager::showFlows(std::basic_ostream<char> &out, int limit) const {

	int current_limit = 0;

	show_flows(out, [&] (const Flow &f) {
		if (current_limit < limit) {
			++current_limit;
			return true;
		}
		return false;
	});
}

// Appends the first `limit` flows of the node list to the JSON value `out`.
void FlowManager::showFlows(Json &out, int limit) const {

	int accepted = 0;

	// Accepts flows until `limit` of them have been taken
	auto under_limit = [&accepted, &limit] (const Flow &) -> bool {
		if (accepted >= limit)
			return false;

		++accepted;
		return true;
	};

	show_flows(out, under_limit);
}

// Writes to `out`, as a table, at most `limit` flows whose protocol name
// equals `protoname` (case is ignored).
//
// Flows without forwarder, or whose forwarder has no protocol, never match.
void FlowManager::showFlows(std::basic_ostream<char> &out, int limit, const std::string &protoname) const {

	int current_limit = 0;

	show_flows(out, [&] (const Flow &f) {
		// Once the limit is reached nothing else can be accepted
		if (current_limit >= limit)
			return false;

		auto ff = f.forwarder.lock();
		if (!ff)
			return false;

		// The forwarder may have no protocol (release() guards the same case),
		// so the pointer is checked before it is used.
		ProtocolPtr proto = ff->getProtocol();
		if (!proto)
			return false;

		if (!boost::iequals(protoname, proto->name()))
			return false;

		++current_limit;
		return true;
	});
}

// Appends to the JSON value `out` at most `limit` flows whose protocol name
// equals `protoname` (case is ignored).
//
// Flows without forwarder, or whose forwarder has no protocol, never match.
void FlowManager::showFlows(Json &out, int limit, const std::string &protoname) const {

	int current_limit = 0;

	show_flows(out, [&] (const Flow &f) {
		// Once the limit is reached nothing else can be accepted
		if (current_limit >= limit)
			return false;

		auto ff = f.forwarder.lock();
		if (!ff)
			return false;

		// The forwarder may have no protocol (release() guards the same case),
		// so the pointer is checked before it is used.
		ProtocolPtr proto = ff->getProtocol();
		if (!proto)
			return false;

		if (!boost::iequals(protoname, proto->name()))
			return false;

		++current_limit;
		return true;
	});
}

void FlowManager::show_flows(std::basic_ostream<char> &out, std::function<bool (const Flow&)> condition) const {

	std::time_t current_time = std::time(nullptr);

	out << "\n"
		<< boost::format("%-64s %-10s %-10s %-12s %-8s %-12s") % "Flow" % "Bytes" % "Packets" % "L7Protocol" % "State" % "Info"
		<< "\n";

	int total_display = 0;

	// The flows are sorted by most recent activity, so the actives will be first
	// and the ones with low activity will be the last to show
	FlowNode *node = flowTable_.head;
	while (node) {
		SharedPointer<Flow> flow = node->flow;
		const Flow& cflow = *flow.get();

		if (condition(cflow)) {
			++total_display;
			print_pretty_flow(out, cflow, current_time);
			out << '\n';
		}
		node = node->next;
	}
	out << "Total " << total_display << std::endl;
}

// Appends to the JSON value `out` one object for every flow accepted by
// `condition`, in list order (head to tail).
void FlowManager::show_flows(Json &out, std::function<bool (const Flow&)> condition) const {

	// The list is ordered by most recent activity: active flows come first
	// and the ones with low activity are the last to show
	for (FlowNode *cursor = flowTable_.head; cursor; cursor = cursor->next) {
		SharedPointer<Flow> holder = cursor->flow;
		const Flow &current = *holder.get();

		if (!condition(current))
			continue;

		Json entry;

		print_pretty_flow(entry, current);
		out.push_back(entry);
	}
}

// Looks a flow up from textual search options.
//
// The source address decides the family: when it parses as IPv4 both
// addresses are handled as IPv4, otherwise both are parsed as IPv6. The two
// hashes (one per direction) are computed from the addresses, ports and
// protocol, xor-ed with the tag when fsos.have_tag is set, and searched in
// the table in that order.
//
// Returns the flow found, or an empty pointer when no flow matches or when
// one of the addresses cannot be parsed.
SharedPointer<Flow> FlowManager::find(const FlowSearchOptions &fsos) {

	IPAddress addr;
	struct in_addr src4 {};

	// inet_aton() accepts the same notations as inet_addr() but reports a
	// failure apart from the value, so "255.255.255.255" (INADDR_NONE) is
	// not mistaken for an unparsable address.
	if (inet_aton(fsos.ipsrc, &src4) != 0) {
		struct in_addr dst4 {};

		if (inet_aton(fsos.ipdst, &dst4) == 0)
			return nullptr;

		addr.setSourceAddress(src4.s_addr);
		addr.setDestinationAddress(dst4.s_addr);
	} else { // The address is IPv6
		struct in6_addr src6 {};
		struct in6_addr dst6 {};

		// Do not search with half-filled addresses when the text is not valid
		if (inet_pton(AF_INET6, fsos.ipsrc, &src6) != 1 || inet_pton(AF_INET6, fsos.ipdst, &dst6) != 1)
			return nullptr;

		addr.setSourceAddress6(&src6);
		addr.setDestinationAddress6(&dst6);
	}

	unsigned long hash1 = addr.getHash(fsos.portsrc, fsos.protocol, fsos.portdst);
	unsigned long hash2 = addr.getHash(fsos.portdst, fsos.protocol, fsos.portsrc);

	if (fsos.have_tag == true) {
		hash1 = hash1 ^ fsos.tag;
		hash2 = hash2 ^ fsos.tag;
	}

	auto it = flowTable_.table_.find(hash1);
	if (it == flowTable_.table_.end()) {
		if (it = flowTable_.table_.find(hash2); it == flowTable_.table_.end())
			return nullptr;
	}
	return ((*it).second)->flow;
}

} // namespace aiengine
